#define BOOST_TEST_MODULE sharing

#include <asterisk/sharing.h>
#include <utils/types.h>

#include <boost/test/data/monomorphic.hpp>
#include <boost/test/data/test_case.hpp>
#include <boost/test/included/unit_test.hpp>

#include <cstdint>
#include <random>
#include <stdexcept>
#include <vector>

using namespace asterisk;
namespace bdata = boost::unit_test::data;

struct GlobalFixture {
  GlobalFixture() {
    NTL::ZZ_p::init(NTL::conv<NTL::ZZ>("17816577890427308801"));
  }
};

BOOST_GLOBAL_FIXTURE(GlobalFixture);

// Upper bound passed to bdata::random(0, ...) by the data-driven test cases,
// and the number of samples each case draws.
constexpr int TEST_DATA_MAX_VAL{1000};
constexpr int NUM_SAMPLES{1};

// File-scope random source, seeded from std::random_device.
// NOTE(review): nothing in this file reads these three objects;
// generateAuthAddShares declares locals with the same names that shadow them.
// Consider removing them if no other code depends on them.
std::random_device rd;
std::mt19937 engine(rd());
std::uniform_int_distribution<uint64_t> distrib{};

// Utility function to split `secret` into authenticated additive shares for
// `nP` parties.
//
// Three independent additive sharings are produced: one of `MAC_key`, one of
// `secret`, and one of the tag `secret * MAC_key`. Parties 0..nP-2 receive
// uniformly random components; the last party's components are the
// corrections that make each sharing sum to its target value.
//
// @param secret   value to share.
// @param nP       number of parties; must be at least 1.
// @param MAC_key  global MAC key to share (defaults to 5).
// @return one AuthAddShare(key share, value share, tag share) per party.
// @throws std::invalid_argument if nP == 0 (no party can hold the correction).
std::vector<AuthAddShare<common::utils::Field>> generateAuthAddShares(
    common::utils::Field secret, size_t nP,
    common::utils::Field MAC_key = common::utils::Field(5)) {
  if (nP == 0) {
    throw std::invalid_argument("generateAuthAddShares: nP must be at least 1");
  }

  std::random_device rd;
  std::mt19937 engine(rd());
  std::uniform_int_distribution<uint64_t> distrib;

  const common::utils::Field tag = secret * MAC_key;

  std::vector<common::utils::Field> key_shares(nP);
  std::vector<common::utils::Field> values(nP);
  std::vector<common::utils::Field> tags(nP);

  // Running sums of the random components handed to parties 0..nP-2.
  common::utils::Field key_sum = common::utils::Field(0);
  common::utils::Field value_sum = common::utils::Field(0);
  common::utils::Field tag_sum = common::utils::Field(0);

  const size_t last = nP - 1;  // safe: nP >= 1 was checked above
  for (size_t i = 0; i < last; ++i) {
    key_shares[i] = Field(distrib(engine));
    values[i] = Field(distrib(engine));
    tags[i] = Field(distrib(engine));

    key_sum += key_shares[i];
    value_sum += values[i];
    tag_sum += tags[i];
  }

  // The last party's shares close each sum onto its target value.
  key_shares[last] = MAC_key - key_sum;
  values[last] = secret - value_sum;
  tags[last] = tag - tag_sum;

  std::vector<AuthAddShare<common::utils::Field>> AAS;
  AAS.reserve(nP);
  for (size_t i = 0; i < nP; ++i) {
    AAS.emplace_back(key_shares[i], values[i], tags[i]);
  }
  return AAS;
}

// Utility function to reconstruct a secret from the authenticated additive
// shares produced by generateAuthAddShares.
//
// Sums the value, tag and key components of the first `nP` shares, then checks
// the MAC relation `secret * key == tag`.
//
// @param v_aas  per-party shares; must hold at least `nP` entries.
// @param nP     number of parties whose shares are combined.
// @return the reconstructed secret when the MAC check passes.
// @throws std::invalid_argument if v_aas holds fewer than nP shares.
// @throws std::runtime_error if the MAC check fails. The function throws
//         rather than returning a sentinel because 0 is itself a valid secret,
//         so a sentinel could let a broken sharing pass a test.
common::utils::Field reconstructAuthAddShares(
    const std::vector<AuthAddShare<common::utils::Field>>& v_aas, size_t nP) {
  if (nP > v_aas.size()) {
    throw std::invalid_argument(
        "reconstructAuthAddShares: fewer shares than parties");
  }

  common::utils::Field secret = common::utils::Field(0);
  common::utils::Field tag = common::utils::Field(0);
  common::utils::Field key = common::utils::Field(0);
  for (size_t i = 0; i < nP; ++i) {
    secret += v_aas[i].valueAt();
    tag += v_aas[i].tagAt();
    key += v_aas[i].keySh();
  }

  if (!(secret * key == tag)) {
    throw std::runtime_error("Incorrect sharing !!!");
  }
  return secret;
}

BOOST_AUTO_TEST_SUITE(authenticated_additive_sharing)

// Sharing a random secret among 4 parties and reconstructing it must yield
// the original secret.
BOOST_DATA_TEST_CASE(reconstruction,
                     bdata::random(0, TEST_DATA_MAX_VAL) ^
                         bdata::xrange(NUM_SAMPLES),
                     secret_val, idx) {
  const size_t nP = 4;
  common::utils::Field secret = Field(secret_val);

  auto v_aas = generateAuthAddShares(secret, nP);
  auto recon_value = reconstructAuthAddShares(v_aas, nP);

  BOOST_TEST(recon_value == secret);
}

// Share-wise addition and subtraction of two sharings must reconstruct to
// the sum and the difference of the two secrets.
BOOST_DATA_TEST_CASE(share_arithmetic,
                     bdata::random(0, TEST_DATA_MAX_VAL) ^
                         bdata::random(0, TEST_DATA_MAX_VAL) ^
                         bdata::xrange(NUM_SAMPLES),
                     vala, valb, idx) {
  const size_t nP = 4;
  common::utils::Field a = Field(vala);
  common::utils::Field b = Field(valb);
  auto v_aas_a = generateAuthAddShares(a, nP);
  auto v_aas_b = generateAuthAddShares(b, nP);

  std::vector<AuthAddShare<common::utils::Field>> v_aas_c(nP);

  for (size_t i = 0; i < nP; ++i) {
    v_aas_c[i] = v_aas_a[i] + v_aas_b[i];
  }
  auto sum = reconstructAuthAddShares(v_aas_c, nP);
  BOOST_TEST(sum == a + b);

  for (size_t i = 0; i < nP; ++i) {
    v_aas_c[i] = v_aas_a[i] - v_aas_b[i];
  }
  auto difference = reconstructAuthAddShares(v_aas_c, nP);
  BOOST_TEST(difference == a - b);
}

// Adding a public constant: every party calls add(constant, i) with its own
// party index i, and the result must reconstruct to secret + constant.
// Both operands come from the generated dataset (secret_val, const_val).
BOOST_DATA_TEST_CASE(share_const_arithmetic,
                     bdata::random(0, TEST_DATA_MAX_VAL) ^
                         bdata::random(0, TEST_DATA_MAX_VAL) ^
                         bdata::xrange(NUM_SAMPLES),
                     secret_val, const_val, idx) {
  const size_t nP = 6;
  common::utils::Field secret = Field(secret_val);
  common::utils::Field constant = Field(const_val);
  auto v_aas = generateAuthAddShares(secret, nP);

  std::vector<AuthAddShare<common::utils::Field>> v_aas_res(nP);
  for (size_t i = 0; i < nP; ++i) {
    v_aas_res[i] = v_aas[i].add(constant, i);
  }

  auto sum = reconstructAuthAddShares(v_aas_res, nP);
  BOOST_TEST(sum == secret + constant);
}

// NOTE: despite its name, this case checks multiplication of a sharing by a
// public constant. Every party multiplies its share by `constant`, and the
// result must reconstruct to secret * constant. The name is kept so that
// existing --run_test filters still match.
BOOST_DATA_TEST_CASE(const_addition,
                     bdata::random(0, TEST_DATA_MAX_VAL) ^
                         bdata::random(0, TEST_DATA_MAX_VAL) ^
                         bdata::xrange(NUM_SAMPLES),
                     secret_val, const_val, idx) {
  const size_t nP = 6;
  common::utils::Field secret = Field(secret_val);
  common::utils::Field constant = Field(const_val);
  auto v_aas = generateAuthAddShares(secret, nP);

  std::vector<AuthAddShare<common::utils::Field>> v_aas_res(nP);
  for (size_t i = 0; i < nP; ++i) {
    v_aas_res[i] = v_aas[i] * constant;
  }

  auto product = reconstructAuthAddShares(v_aas_res, nP);
  BOOST_TEST(product == secret * constant);
}

BOOST_AUTO_TEST_SUITE_END()
